
Anime Tagger eva02_large_patch14_448.dbv4-full

Model Details

  • Model Type: Multilabel Image classification / feature backbone
  • Model Stats:
    • Params: 316.8M
    • FLOPs / MACs: 620.9G / 310.1G
    • Image size: train = 448 x 448, test = 448 x 448
  • Dataset: animetimm/danbooru-wdtagger-v4-w640-ws-full
    • Tags Count: 12476
      • General (#0) Tags Count: 9225
      • Character (#4) Tags Count: 3247
      • Rating (#9) Tags Count: 4

Results

| Split | Macro@0.40 (F1/MCC/P/R) | Micro@0.40 (F1/MCC/P/R) | Macro@Best (F1/P/R) |
|---|---|---|---|
| Validation | 0.528 / 0.537 / 0.600 / 0.503 | 0.678 / 0.678 / 0.693 / 0.664 | --- |
| Test | 0.529 / 0.538 / 0.601 / 0.503 | 0.679 / 0.678 / 0.694 / 0.665 | 0.574 / 0.580 / 0.591 |
  • Macro/Micro@0.40: metrics computed at a fixed threshold of 0.40.
  • Macro@Best: mean of the per-tag metrics, with each tag evaluated at the threshold that maximizes its F1 score.
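To make the micro vs. macro distinction concrete, here is a minimal NumPy sketch (illustrative, not taken from the model repo): micro-averaging pools true/false positive counts across all tags before computing one F1, while macro-averaging computes F1 per tag and then averages.

```python
import numpy as np

def f1_at_threshold(probs, labels, thr=0.40):
    """Micro- and macro-averaged F1 over tags at a fixed threshold.

    probs, labels: (n_samples, n_tags) arrays; labels are 0/1 ground truth.
    """
    preds = (probs >= thr).astype(int)
    labels = labels.astype(int)
    tp = (preds * labels).sum(axis=0)        # true positives per tag
    fp = (preds * (1 - labels)).sum(axis=0)  # false positives per tag
    fn = ((1 - preds) * labels).sum(axis=0)  # false negatives per tag
    # Micro: pool counts over all tags, then compute a single F1.
    micro = 2 * tp.sum() / (2 * tp.sum() + fp.sum() + fn.sum())
    # Macro: one F1 per tag, then average (empty denominators count as 0).
    denom = 2 * tp + fp + fn
    per_tag = np.where(denom > 0, 2 * tp / np.where(denom > 0, denom, 1), 0.0)
    macro = per_tag.mean()
    return micro, macro
```

Macro@Best works the same way, except `thr` becomes a per-tag vector instead of a single scalar.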

Thresholds

| Category | Name | Alpha | Threshold | Micro@Thr (F1/P/R) | Macro@0.40 (F1/P/R) | Macro@Best (F1/P/R) |
|---|---|---|---|---|---|---|
| 0 | general | 1 | 0.37 | 0.667 / 0.666 / 0.668 | 0.399 / 0.493 / 0.367 | 0.453 / 0.452 / 0.483 |
| 4 | character | 1 | 0.57 | 0.922 / 0.951 / 0.895 | 0.896 / 0.909 / 0.889 | 0.917 / 0.943 / 0.895 |
| 9 | rating | 1 | 0.4 | 0.824 / 0.784 / 0.868 | 0.829 / 0.800 / 0.862 | 0.831 / 0.806 / 0.859 |
  • Micro@Thr: metrics computed at the category-level suggested thresholds listed in the table above.
  • Macro@0.40: metrics computed at a fixed threshold of 0.40.
  • Macro@Best: metrics computed at per-tag thresholds, with each tag evaluated at the threshold that maximizes its F1 score.

The tag-level thresholds can be found in selected_tags.csv.
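Applying these per-tag thresholds is a straightforward element-wise comparison. Below is a runnable sketch using a toy stand-in for selected_tags.csv; the `name` and `best_threshold` columns match the real file (they are used in the code further down), while the tag names, the `category` column name, and the score values here are illustrative assumptions.

```python
import numpy as np
import pandas as pd

# Toy stand-in for selected_tags.csv (the real file ships in this repo);
# tag names and values below are made up for illustration.
df_tags = pd.DataFrame({
    'name': ['1girl', 'smile', 'hatsune_miku', 'sensitive'],
    'category': [0, 0, 4, 9],          # 0 = general, 4 = character, 9 = rating
    'best_threshold': [0.37, 0.42, 0.57, 0.40],
})

prediction = np.array([0.95, 0.30, 0.70, 0.65])  # per-tag sigmoid scores

# Keep each tag only if its score clears that tag's own best-F1 threshold.
mask = prediction >= df_tags['best_threshold'].to_numpy()
hits = dict(zip(df_tags['name'][mask], prediction[mask].tolist()))
print(hits)
# {'1girl': 0.95, 'hatsune_miku': 0.7, 'sensitive': 0.65}
```

Note that 'smile' is dropped even though its score (0.30) would pass a lower global threshold, because its own tag-level threshold is 0.42.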

How to Use

A sample image (sample.webp in this repository) is used in the code samples below.

Use TIMM And Torch

Install dghs-imgutils, timm, and the other required packages with the following command:

pip install 'dghs-imgutils>=0.17.0' torch huggingface_hub timm pillow pandas

You can then load this model with the timm library and use it for training, validation, and testing with the following code:

import json

import pandas as pd
import torch
from huggingface_hub import hf_hub_download
from imgutils.data import load_image
from imgutils.preprocess import create_torchvision_transforms
from timm import create_model

repo_id = 'animetimm/eva02_large_patch14_448.dbv4-full'
model = create_model(f'hf-hub:{repo_id}', pretrained=True)
model.eval()

with open(hf_hub_download(repo_id=repo_id, repo_type='model', filename='preprocess.json'), 'r') as f:
    preprocessor = create_torchvision_transforms(json.load(f)['test'])
# Compose(
#     PadToSize(size=(512, 512), interpolation=bilinear, background_color=white)
#     Resize(size=(448, 448), interpolation=bicubic, max_size=None, antialias=True)
#     CenterCrop(size=[448, 448])
#     MaybeToTensor()
#     Normalize(mean=tensor([0.4815, 0.4578, 0.4082]), std=tensor([0.2686, 0.2613, 0.2758]))
# )

image = load_image('https://huggingface.co/animetimm/eva02_large_patch14_448.dbv4-full/resolve/main/sample.webp')
input_ = preprocessor(image).unsqueeze(0)
# input_, shape: torch.Size([1, 3, 448, 448]), dtype: torch.float32
with torch.no_grad():
    output = model(input_)
    prediction = torch.sigmoid(output)[0]
# output, shape: torch.Size([1, 12476]), dtype: torch.float32
# prediction, shape: torch.Size([12476]), dtype: torch.float32

df_tags = pd.read_csv(
    hf_hub_download(repo_id=repo_id, repo_type='model', filename='selected_tags.csv'),
    keep_default_na=False
)
tags = df_tags['name']
mask = prediction.numpy() >= df_tags['best_threshold']
print(dict(zip(tags[mask].tolist(), prediction[mask].tolist())))
# {'sensitive': 0.6976025700569153,
#  '1girl': 0.9952899217605591,
#  'solo': 0.9671481847763062,
#  'looking_at_viewer': 0.7711699604988098,
#  'blush': 0.7974982261657715,
#  'smile': 0.8849270939826965,
#  'short_hair': 0.817248523235321,
#  'long_sleeves': 0.5171797275543213,
#  'brown_hair': 0.6675055623054504,
#  'dress': 0.6894800662994385,
#  'closed_mouth': 0.35917922854423523,
#  'sitting': 0.7595945000648499,
#  'purple_eyes': 0.8275928497314453,
#  'flower': 0.8742285966873169,
#  'braid': 0.8496974110603333,
#  'blunt_bangs': 0.39164724946022034,
#  'tears': 0.8591281771659851,
#  'floral_print': 0.44396182894706726,
#  'crying': 0.4951671063899994,
#  'plant': 0.758698046207428,
#  'blue_flower': 0.5387876629829407,
#  'tearing_up': 0.11903537809848785,
#  'crying_with_eyes_open': 0.3073916733264923,
#  'crown_braid': 0.7725721001625061,
#  'potted_plant': 0.8286207318305969,
#  'flower_pot': 0.6531336307525635,
#  'happy_tears': 0.3884831964969635,
#  'pavement': 0.2094476968050003,
#  'wiping_tears': 0.6769278645515442,
#  'holding_flower_pot': 0.12655559182167053}
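The example above uses per-tag thresholds and returns one flat dictionary. If you instead want results split by category, as the ONNX helper below produces, you can group the prediction vector using the category IDs and the category-level suggested thresholds from the Thresholds table. This is an illustrative sketch with toy data; it assumes selected_tags.csv exposes the category ID in a `category` column (the IDs 0/4/9 match the table, but the column name is an assumption).

```python
import numpy as np
import pandas as pd

# Category-level suggested thresholds and names, from the Thresholds table.
category_thresholds = {0: 0.37, 4: 0.57, 9: 0.40}
category_names = {0: 'general', 4: 'character', 9: 'rating'}

# Toy stand-in for selected_tags.csv and the model's sigmoid output.
df_tags = pd.DataFrame({
    'name': ['1girl', 'smile', 'hatsune_miku', 'sensitive'],
    'category': [0, 0, 4, 9],
})
prediction = np.array([0.95, 0.30, 0.50, 0.65])

results = {}
for cat_id, group in df_tags.groupby('category'):
    thr = category_thresholds[cat_id]
    scores = prediction[group.index.to_numpy()]
    keep = scores >= thr
    results[category_names[cat_id]] = dict(
        zip(group['name'][keep], scores[keep].tolist()))
print(results)
```

Here 'hatsune_miku' scores 0.50, below the character threshold of 0.57, so the character bucket comes back empty, mirroring the empty character dict in the ONNX example below.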

Use ONNX Model For Inference

Install dghs-imgutils with the following command:

pip install 'dghs-imgutils>=0.17.0'

Use the multilabel_timm_predict function with the following code:

from imgutils.generic import multilabel_timm_predict

general, character, rating = multilabel_timm_predict(
    'https://huggingface.co/animetimm/eva02_large_patch14_448.dbv4-full/resolve/main/sample.webp',
    repo_id='animetimm/eva02_large_patch14_448.dbv4-full',
    fmt=('general', 'character', 'rating'),
)

print(general)
# {'1girl': 0.9952900409698486,
#  'solo': 0.9671480655670166,
#  'smile': 0.8849270343780518,
#  'flower': 0.8742280602455139,
#  'tears': 0.8591268062591553,
#  'braid': 0.8496923446655273,
#  'potted_plant': 0.8286197185516357,
#  'purple_eyes': 0.8275918364524841,
#  'short_hair': 0.8172485828399658,
#  'blush': 0.7974982857704163,
#  'crown_braid': 0.772567629814148,
#  'looking_at_viewer': 0.7711694240570068,
#  'sitting': 0.759594738483429,
#  'plant': 0.7586977481842041,
#  'dress': 0.6894786357879639,
#  'wiping_tears': 0.6769236326217651,
#  'brown_hair': 0.6675049662590027,
#  'flower_pot': 0.6531318426132202,
#  'blue_flower': 0.5387848615646362,
#  'long_sleeves': 0.5171791315078735,
#  'crying': 0.4951639473438263,
#  'floral_print': 0.44396066665649414,
#  'blunt_bangs': 0.39164483547210693,
#  'happy_tears': 0.3884800672531128,
#  'closed_mouth': 0.3591785430908203,
#  'crying_with_eyes_open': 0.30738943815231323,
#  'pavement': 0.20944759249687195,
#  'holding_flower_pot': 0.12655416131019592,
#  'tearing_up': 0.11903449892997742}
print(character)
# {}
print(rating)
# {'sensitive': 0.6976030468940735}

For further information, see the documentation of the multilabel_timm_predict function.
